#include <cstdint>

#include <torch/torch.h>

// CUDA forward declarations
// Forward-pass launcher. Only the declaration lives in this file; the
// definition is provided elsewhere.
//
// chamfer_distance_forward_cuda() below is the only caller visible here. It
// passes b = xyz1.size(0), n = xyz1.size(1), m = xyz2.size(1), the float
// buffers of xyz1/xyz2 as xyz/xyz2, and the buffers of dist1, idx1, dist2,
// idx2 as result, result_i, result2, result2_i respectively.
//
// NOTE(review): the meaning of the int return value is not visible from this
// file, and the caller below discards it — confirm against the definition.
int ChamferDistanceKernelLauncher(
    const int b, const int n,
    const float *xyz,
    const int m,
    const float *xyz2,
    float *result,
    int *result_i,
    float *result2,
    int *result2_i);

// Backward-pass launcher. Only the declaration lives in this file; the
// definition is provided elsewhere.
//
// chamfer_distance_backward_cuda() below is the only caller visible here. It
// passes b = xyz1.size(0), n = xyz1.size(1), m = xyz2.size(1), followed by the
// raw buffers of xyz1, xyz2, graddist1, idx1, graddist2, idx2, gradxyz1 and
// gradxyz2 in the parameter order declared here.
//
// NOTE(review): the meaning of the int return value is not visible from this
// file, and the caller below discards it — confirm against the definition.
int ChamferDistanceGradKernelLauncher(
    const int b, const int n,
    const float *xyz1,
    const int m,
    const float *xyz2,
    const float *grad_dist1,
    const int *idx1,
    const float *grad_dist2,
    const int *idx2,
    float *grad_xyz1,
    float *grad_xyz2);

// Python-facing forward entry point that hands the work to
// ChamferDistanceKernelLauncher.
//
// Sizes are read from xyz1 (dims 0 and 1) and xyz2 (dim 1); xyz1, xyz2, dist1
// and dist2 are accessed as float buffers, idx1 and idx2 as int buffers.
// dist1/idx1/dist2/idx2 are passed to the launcher as writable output pointers.
//
// NOTE(review): nothing here verifies that the tensors are contiguous or live
// on the device the launcher expects — presumably the Python caller guarantees
// this; confirm.
void chamfer_distance_forward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor dist1,
    const at::Tensor dist2,
    const at::Tensor idx1,
    const at::Tensor idx2)
{
    // Problem dimensions, narrowed to the int parameters of the launcher.
    const int batch_count = xyz1.size(0);
    const int count1 = xyz1.size(1);
    const int count2 = xyz2.size(1);

    // Raw buffers handed to the launcher.
    const float *points1 = xyz1.data<float>();
    const float *points2 = xyz2.data<float>();
    float *out_dist1 = dist1.data<float>();
    int *out_idx1 = idx1.data<int>();
    float *out_dist2 = dist2.data<float>();
    int *out_idx2 = idx2.data<int>();

    // NOTE(review): the launcher's int result is dropped, as before.
    ChamferDistanceKernelLauncher(batch_count, count1, points1,
                                  count2, points2,
                                  out_dist1, out_idx1,
                                  out_dist2, out_idx2);
}

// Python-facing backward entry point that hands the work to
// ChamferDistanceGradKernelLauncher.
//
// Sizes are read from xyz1 (dims 0 and 1) and xyz2 (dim 1). xyz1, xyz2,
// graddist1, graddist2, gradxyz1 and gradxyz2 are accessed as float buffers,
// idx1 and idx2 as int buffers. gradxyz1/gradxyz2 are passed to the launcher
// as writable output pointers.
//
// NOTE(review): nothing here verifies that the tensors are contiguous or live
// on the device the launcher expects — presumably the Python caller guarantees
// this; confirm.
void chamfer_distance_backward_cuda(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor gradxyz1,
    at::Tensor gradxyz2,
    at::Tensor graddist1,
    at::Tensor graddist2,
    at::Tensor idx1,
    at::Tensor idx2)
{
    // Problem dimensions, narrowed to the int parameters of the launcher.
    const int batch_count = xyz1.size(0);
    const int count1 = xyz1.size(1);
    const int count2 = xyz2.size(1);

    // Inputs.
    const float *points1 = xyz1.data<float>();
    const float *points2 = xyz2.data<float>();
    const float *upstream1 = graddist1.data<float>();
    const int *nearest1 = idx1.data<int>();
    const float *upstream2 = graddist2.data<float>();
    const int *nearest2 = idx2.data<int>();

    // Outputs.
    float *out_grad1 = gradxyz1.data<float>();
    float *out_grad2 = gradxyz2.data<float>();

    // NOTE(review): the launcher's int result is dropped, as before.
    ChamferDistanceGradKernelLauncher(batch_count, count1, points1,
                                      count2, points2,
                                      upstream1, nearest1,
                                      upstream2, nearest2,
                                      out_grad1, out_grad2);
}

// Brute-force nearest-neighbour search on the CPU.
//
// For every point of xyz1, scans all m points of xyz2 in the same batch and
// records the smallest squared Euclidean distance and the index of the point
// that achieved it.
//
//   b, n, m  number of batches, points per batch in xyz1, points per batch in xyz2
//   xyz1     b * n * 3 floats, consecutive (x, y, z) triples, batch-major
//   xyz2     b * m * 3 floats, same layout
//   dist     output, b * n floats: squared distance to the nearest xyz2 point
//   idx      output, b * n ints: index (0..m-1, within the batch) of that point
//
// Ties keep the lowest index because only a strictly smaller distance replaces
// the current best. When m == 0 the inner loop never runs and each entry is
// written as dist = 0, idx = 0.
//
// All flat offsets are computed in 64-bit arithmetic so that b * n * 3 or
// b * m * 3 exceeding INT_MAX cannot overflow a 32-bit int.
void nnsearch(
    const int b, const int n, const int m,
    const float *xyz1,
    const float *xyz2,
    float *dist,
    int *idx)
{
    for (int i = 0; i < b; i++)
    {
        // First candidate point of this batch.
        const float *candidates = xyz2 + static_cast<int64_t>(i) * m * 3;

        for (int j = 0; j < n; j++)
        {
            // Flat index of the query point; also its slot in dist / idx.
            const int64_t q = static_cast<int64_t>(i) * n + j;
            const float x1 = xyz1[q * 3 + 0];
            const float y1 = xyz1[q * 3 + 1];
            const float z1 = xyz1[q * 3 + 2];

            double best = 0;
            int besti = 0;
            for (int k = 0; k < m; k++)
            {
                const float *c = candidates + static_cast<int64_t>(k) * 3;
                const float dx = c[0] - x1;
                const float dy = c[1] - y1;
                const float dz = c[2] - z1;
                const double d = dx * dx + dy * dy + dz * dz;
                // The first candidate always seeds the running minimum.
                if (k == 0 || d < best)
                {
                    best = d;
                    besti = k;
                }
            }
            dist[q] = static_cast<float>(best);
            idx[q] = besti;
        }
    }
}

// CPU forward pass: runs nnsearch() in both directions.
//
// dist1/idx1 receive, for each point of xyz1, the nnsearch() result against
// xyz2; dist2/idx2 receive the same for each point of xyz2 against xyz1.
// xyz1, xyz2, dist1 and dist2 are accessed as float buffers, idx1 and idx2 as
// int buffers.
//
// NOTE(review): the raw-pointer indexing in nnsearch() assumes densely packed
// buffers — presumably the caller passes contiguous CPU tensors; confirm.
void chamfer_distance_forward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    const at::Tensor dist1,
    const at::Tensor dist2,
    const at::Tensor idx1,
    const at::Tensor idx2)
{
    const int num_batches = xyz1.size(0);
    const int count1 = xyz1.size(1);
    const int count2 = xyz2.size(1);

    // Every buffer is resolved before any search runs, so a failing accessor
    // leaves all outputs untouched.
    const float *points1 = xyz1.data<float>();
    const float *points2 = xyz2.data<float>();
    float *nearest_dist1 = dist1.data<float>();
    float *nearest_dist2 = dist2.data<float>();
    int *nearest_idx1 = idx1.data<int>();
    int *nearest_idx2 = idx2.data<int>();

    // Direction 1: each xyz1 point against the xyz2 points of its batch.
    nnsearch(num_batches, count1, count2,
             points1, points2,
             nearest_dist1, nearest_idx1);

    // Direction 2: each xyz2 point against the xyz1 points of its batch.
    nnsearch(num_batches, count2, count1,
             points2, points1,
             nearest_dist2, nearest_idx2);
}

// Accumulates the gradient of one search direction for a single batch.
//
// For each of the `count` source points j, with partner p = idx[j] and
// g = 2 * grad_dist[j], adds g * (src[j] - dst[p]) to grad_src[j] and
// subtracts the same vector from grad_dst[p]. All pointers address the first
// element of the current batch; offsets are 64-bit.
static void chamfer_backward_one_direction(
    const int64_t count,
    const float *src,
    const float *dst,
    const float *grad_dist,
    const int *idx,
    float *grad_src,
    float *grad_dst)
{
    for (int64_t j = 0; j < count; j++)
    {
        const int64_t partner = idx[j];
        const float g = grad_dist[j] * 2;
        for (int c = 0; c < 3; c++)
        {
            const float delta = g * (src[j * 3 + c] - dst[partner * 3 + c]);
            grad_src[j * 3 + c] += delta;
            grad_dst[partner * 3 + c] -= delta;
        }
    }
}

// CPU backward pass.
//
// Zero-fills gradxyz1 (b * n * 3 floats) and gradxyz2 (b * m * 3 floats), then
// for every batch accumulates the contribution of the xyz1 -> xyz2 matches
// (graddist1 / idx1) followed by the xyz2 -> xyz1 matches (graddist2 / idx2),
// where b = xyz1.size(0), n = xyz1.size(1), m = xyz2.size(1).
//
// Sizes and flat offsets are kept in 64-bit integers so that large inputs
// cannot overflow 32-bit index arithmetic.
//
// NOTE(review): the raw-pointer indexing assumes densely packed buffers and
// idx values that are valid point indices within the batch — presumably
// guaranteed by the caller; confirm.
void chamfer_distance_backward(
    const at::Tensor xyz1,
    const at::Tensor xyz2,
    at::Tensor gradxyz1,
    at::Tensor gradxyz2,
    at::Tensor graddist1,
    at::Tensor graddist2,
    at::Tensor idx1,
    at::Tensor idx2)
{
    const int64_t b = xyz1.size(0);
    const int64_t n = xyz1.size(1);
    const int64_t m = xyz2.size(1);

    const float *xyz1_data = xyz1.data<float>();
    const float *xyz2_data = xyz2.data<float>();
    float *gradxyz1_data = gradxyz1.data<float>();
    float *gradxyz2_data = gradxyz2.data<float>();
    const float *graddist1_data = graddist1.data<float>();
    const float *graddist2_data = graddist2.data<float>();
    const int *idx1_data = idx1.data<int>();
    const int *idx2_data = idx2.data<int>();

    // The accumulation below uses += / -=, so both outputs start from zero.
    for (int64_t i = 0; i < b * n * 3; i++)
        gradxyz1_data[i] = 0;
    for (int64_t i = 0; i < b * m * 3; i++)
        gradxyz2_data[i] = 0;

    for (int64_t i = 0; i < b; i++)
    {
        // Offsets of batch i in the per-point and per-coordinate buffers.
        const int64_t pts1 = i * n;
        const int64_t pts2 = i * m;

        // xyz1 -> xyz2 matches.
        chamfer_backward_one_direction(
            n,
            xyz1_data + pts1 * 3, xyz2_data + pts2 * 3,
            graddist1_data + pts1, idx1_data + pts1,
            gradxyz1_data + pts1 * 3, gradxyz2_data + pts2 * 3);

        // xyz2 -> xyz1 matches.
        chamfer_backward_one_direction(
            m,
            xyz2_data + pts2 * 3, xyz1_data + pts1 * 3,
            graddist2_data + pts2, idx2_data + pts2,
            gradxyz2_data + pts2 * 3, gradxyz1_data + pts1 * 3);
    }
}

// Python bindings. Exposes the four entry points of this file on the extension
// module as:
//   forward        -> chamfer_distance_forward        (CPU)
//   forward_cuda   -> chamfer_distance_forward_cuda
//   backward       -> chamfer_distance_backward       (CPU)
//   backward_cuda  -> chamfer_distance_backward_cuda
// The module name comes from the TORCH_EXTENSION_NAME macro, which is not
// defined in this file.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("forward", &chamfer_distance_forward, "ChamferDistance forward");
    m.def("forward_cuda", &chamfer_distance_forward_cuda, "ChamferDistance forward (CUDA)");
    m.def("backward", &chamfer_distance_backward, "ChamferDistance backward");
    m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)");
}
